# ===========================================================================
# QC-Bench Fine-tuning Script
# ===========================================================================
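# This script fine-tunes the Gemma-2 2B instruction-tuned model
# (google/gemma-2-2b-it) on quantum computing QA data and then evaluates it
# on a multiple-choice test set. The fine-tuning experiments were conducted
# on a local cluster with GPU acceleration.
#
# We use LoRA (Low-Rank Adaptation): the base weights stay frozen and only
# small low-rank adapter matrices on the attention projections are trained,
# which keeps memory and compute requirements low.
#
# WARNING: any existing output directory (ft_gemma2) is deleted when this
# module is loaded, i.e. whenever the script is run or imported.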
#
# Files needed to run:
# - qc4167_qa.jsonl: Training dataset with 4,167 quantum computing QA pairs
# - qc1000_test.json: Test set with 1,000 multiple-choice questions

import os
import sys
import json
import torch
import random
import re
import shutil
from tqdm import tqdm
from torch.utils.data import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from peft import (
    LoraConfig,
    PeftModel,
    PeftConfig,
    get_peft_model,
    TaskType,
)

# --- Configuration ---------------------------------------------------------
# All paths are plain relative strings, resolved against the current working
# directory.
TRAIN_FILE: str = "qc4167_qa.jsonl"  # JSONL: one {"prompt", "completion"} object per line
TEST_FILE: str = "qc1000_test.json"  # JSON list of multiple-choice question dicts
MODEL_ID: str = "google/gemma-2-2b-it"  # base checkpoint id passed to from_pretrained
OUTPUT_DIR: str = "ft_gemma2"  # Trainer checkpoints and the final LoRA adapter

# Start every run from an empty output directory.
# NOTE(review): this executes at import time, so merely importing this module
# deletes OUTPUT_DIR (including a previously trained adapter). Consider moving
# it into main().
if os.path.exists(OUTPUT_DIR):
    print(f"Removing existing model at {OUTPUT_DIR}")
    shutil.rmtree(OUTPUT_DIR)

# Dataset Class
class QADataset(Dataset):
    """Prompt/completion pairs read from a JSONL file, tokenized on access.

    Each non-blank line of ``path`` is parsed as a JSON object whose
    ``"prompt"`` and ``"completion"`` string fields are stripped and joined as
    ``"<prompt>\\n\\nAnswer: <completion>"``. Lines that are not valid JSON, or
    lack those fields, or hold non-string values, are skipped and counted.

    Args:
        path: Path to a UTF-8 encoded JSONL file.
        tokenizer: Tokenizer callable applied to one example string per
            ``__getitem__`` call (e.g. a Hugging Face tokenizer).
        max_length: Length every example is padded/truncated to.
    """

    def __init__(self, path, tokenizer, max_length=512):
        print(f"Loading dataset from {path}")
        self.examples = []
        skipped = 0

        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    # Blank lines (e.g. a trailing newline) are not records.
                    continue
                try:
                    obj = json.loads(line)
                    question = obj['prompt'].strip()
                    answer = obj['completion'].strip()
                except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
                    # Invalid JSON, missing keys, a non-object row, or a
                    # non-string field: skip the record but keep count, rather
                    # than a bare `except:` that would also swallow
                    # KeyboardInterrupt/SystemExit and hide bad data.
                    skipped += 1
                    continue
                self.examples.append(f"{question}\n\nAnswer: {answer}")

        print(f"Loaded {len(self.examples)} examples")
        if skipped:
            print(f"Skipped {skipped} malformed lines in {path}")
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of successfully parsed examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize example ``idx``, padded/truncated to ``max_length``.

        ``return_tensors=None`` leaves the output untensorized; batching is
        left to the data collator.
        """
        return self.tokenizer(
            self.examples[idx],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors=None
        )

# Function to fine-tune the model
def fine_tune():
    """Fine-tune MODEL_ID with LoRA on TRAIN_FILE and save the adapter to OUTPUT_DIR.

    Loads the base model and tokenizer and wraps the model with a rank-8 LoRA
    adapter on the attention projections (q/k/v/o). It then trains for one
    epoch with the Hugging Face ``Trainer`` and writes the adapter weights
    and the tokenizer to ``OUTPUT_DIR``.

    Precision: bf16 mixed precision is used when the GPU supports it,
    otherwise fp16. In both cases the trainable LoRA weights are kept in fp32.
    """
    print(f"Starting fine-tuning of {MODEL_ID}")

    # Load tokenizer; fall back to EOS as the padding token if none is defined.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Gemma-2 activations can overflow in fp16 (NaN/inf losses), so prefer
    # bf16 whenever the hardware supports it and fall back to fp16 otherwise.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    compute_dtype = torch.bfloat16 if use_bf16 else torch.float16

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=compute_dtype,
        device_map="auto",
        # Transformers recommends eager attention when training Gemma-2.
        attn_implementation="eager",
    )

    # Define LoRA configuration
    lora_config = LoraConfig(
        r=8,  # Using a smaller rank for stability
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM
    )

    # Apply LoRA
    model = get_peft_model(model, lora_config)

    # Keep the trainable (LoRA) weights in fp32. Under fp16 AMP the GradScaler
    # refuses to unscale fp16 gradients, and fp32 master weights are more
    # stable in either mode. Frozen base weights stay in compute_dtype.
    for param in model.parameters():
        if param.requires_grad:
            param.data = param.data.float()

    # Report how much of the model is actually being trained.
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable_params} ({trainable_params/total_params:.2%})")

    # Causal-LM collator (mlm=False): labels are the input ids, with padding
    # positions masked out of the loss.
    train_dataset = QADataset(TRAIN_FILE, tokenizer)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Effective batch size = 4 (per device) * 4 (accumulation) = 16.
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=1e-4,  # Lower learning rate
        num_train_epochs=1,
        logging_steps=20,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=1,
        fp16=not use_bf16,
        bf16=use_bf16,
        report_to="none",
        optim="adamw_torch",
        warmup_steps=50,  # Fixed warmup steps instead of ratio
        weight_decay=0.01,
        seed=42,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )

    print("Starting training...")
    trainer.train()

    # On a PEFT model, save_pretrained writes the adapter weights and config.
    print(f"Saving model to {OUTPUT_DIR}")
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)

# Function to evaluate the model
def _build_mcq_prompt(question):
    """Format one test item (keys: question, A, B, C, D) as a zero-shot prompt.

    The prompt ends with ``"Answer:"`` so the continuation should be the
    chosen option letter.
    """
    return f"""You are an expert in quantum computing. Choose the correct answer to the question below.

Question: {question['question']}
A. {question['A']}
B. {question['B']}
C. {question['C']}
D. {question['D']}

Answer with only one letter: A, B, C, or D.

Answer:"""


def _extract_choice(generated_text):
    """Return the first standalone capital option letter A-D, or None.

    Requiring word boundaries and an uppercase letter keeps ordinary words
    from being read as answers (e.g. "Based ..." is not "B", "Answer" is
    not "A", and the article "a" is not "A").
    """
    match = re.search(r"\b([A-D])\b", generated_text)
    return match.group(1) if match else None


def evaluate():
    """Evaluate the LoRA adapter in OUTPUT_DIR on TEST_FILE and report accuracy.

    Each test item is formatted as a zero-shot multiple-choice prompt and
    greedily decoded for up to 5 new tokens. The first standalone option
    letter in the generated continuation is compared against the item's
    ``"solution"`` field.

    Returns:
        Accuracy as a float in [0, 1]. Returns 0.0 when the test set is empty.
    """
    print("Starting evaluation...")

    tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)

    # Gemma-2 is numerically safer in bf16 than fp16; use bf16 when the GPU
    # supports it. The eager attention path is used for inference.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    config = PeftConfig.from_pretrained(OUTPUT_DIR)
    base_model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
        device_map="auto",
        attn_implementation="eager",
    )

    # Attach the trained adapter on top of the base model.
    model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)
    model.eval()

    with open(TEST_FILE, 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    print(f"Loaded {len(test_data)} test questions")

    total = len(test_data)
    if total == 0:
        print("No test questions found; nothing to evaluate.")
        return 0.0

    correct = 0
    for i, question in enumerate(tqdm(test_data, desc="Evaluating")):
        prompt = _build_mcq_prompt(question)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=5,
                do_sample=False
            )

        # generate() returns prompt + continuation; decode only the new tokens
        # so text from the prompt itself can never be parsed as the answer.
        prompt_len = inputs["input_ids"].shape[1]
        output_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        predicted_answer = _extract_choice(output_text)
        # Normalize the gold label in case it carries whitespace/lowercase.
        gold_answer = str(question["solution"]).strip().upper()
        if predicted_answer == gold_answer:
            correct += 1

        if (i+1) % 50 == 0:
            current_acc = correct / (i+1)
            print(f"Progress: {i+1}/{total}, Current accuracy: {current_acc:.2%}")

    final_acc = correct / total
    print(f"\nFINAL ACCURACY: {correct}/{total} = {final_acc:.2%}")
    return final_acc

def main():
    """Script entry point: check prerequisites, fine-tune, then evaluate.

    Exits with status 1 if CUDA is unavailable or either data file is
    missing; otherwise runs ``fine_tune()`` followed by ``evaluate()``.
    """
    print("\n=== STARTING CLEAN GEMMA-2 FINE-TUNING ===\n")

    # Training and evaluation require a GPU; stop early without one.
    cuda_ok = torch.cuda.is_available()
    if not cuda_ok:
        print("CUDA is not available. Please check your GPU setup.")
        sys.exit(1)

    print(f"CUDA available: {cuda_ok}")
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

    # Both data files must exist before any expensive work starts.
    required_files = (("Training", TRAIN_FILE), ("Test", TEST_FILE))
    for label, path in required_files:
        if not os.path.exists(path):
            print(f"{label} file not found: {path}")
            sys.exit(1)

    fine_tune()
    evaluate()


if __name__ == "__main__":
    main()